/*
Copyright 2008 Flaptor (flaptor.com) 

Licensed under the Apache License, Version 2.0 (the "License"); 
you may not use this file except in compliance with the License. 
You may obtain a copy of the License at 

    http://www.apache.org/licenses/LICENSE-2.0 

Unless required by applicable law or agreed to in writing, software 
distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and 
limitations under the License.
*/

package com.flaptor.util;

import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.lang.reflect.Method;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.regex.Pattern;

import org.apache.log4j.PropertyConfigurator;



/**
 * This class replaces the JUnit TestCase and provides test threads management, output
 * filtering, unexpected error detection and a mix of small convenience utilities.
 * 
 * By extending this class instead of the standard JUnit TestCase, any error produced 
 * through log4j will fail the test. 
 * 
 * Expected errors should be filtered using the following methods in the test setUp()
 * or anywhere in the test code:
 *
 *   filterOutput(literal)     omits any output that matches the provided string
 *   filterOutputRegex(regex)  omits any output that matches the provided regular expression
 *   unfilterOutput()          stops filtering
 *
 * Threads in the test code should extend TestThread and implement runTest() and kill().
 * This way, any uncaught exception that is thrown within a thread will fail the test.
 *
 * The kill() method will be called when any thread throws an exception. Also, threads
 * should exit when they receive an InterruptedException or when isInterrupted() returns true.
 *
 * The startTestThreads() method starts all instantiated TestThreads.
 * The waitForTestThreads() method blocks until all running TestThreads finish.
 * The runTestThreads() method does both.
 *
 * This class also takes care of the Log4J configuration.
 *
 */
public class TestCase extends junit.framework.TestCase {

    // Streams captured when this instance is created; every filter installed by this
    // instance wraps them.
    // NOTE(review): if an earlier TestCase instance already replaced System.out/System.err,
    // these hold that instance's wrapper instead of the JVM's original streams, so wrappers
    // nest across instances — confirm this is acceptable.
    private PrintStream stdOut = System.out;
    private PrintStream stdErr = System.err;

    // Name of the test method; runBare() looks the method up by this name.
    private String fName;

    /**
     * Creates an unnamed test case and installs the error-watching output streams.
     * The name must be provided through {@link #setName(String)} before the test runs.
     */
    public TestCase () {
        super();
        fName = null;
        unfilterOutput();
    }

    /**
     * Creates a test case for the given test method and installs the error-watching
     * output streams.
     * @param name the name of the test method to run.
     */
    public TestCase (String name) {
        super(name);
        fName = name;
        unfilterOutput();
    }

    /**
     * Configures Log4J from the "log4j.properties" file found on the classpath, watching the
     * file for changes. If no such file is found, a warning is printed to System.err.
     */
    private void configureLog4J () {
        String log4jConfigPath = com.flaptor.util.FileUtil.getFilePathFromClasspath("log4j.properties");
        if (null != log4jConfigPath) {
            PropertyConfigurator.configureAndWatch(log4jConfigPath);
        } else {
            System.err.println("log4j.properties not found on classpath!");
        }
    }


    // Output filters.


    /**
     * A print stream that inspects every chunk of bytes written to it. If the first line of
     * the chunk matches a regular expression, the chunk is handed to {@link #match};
     * otherwise it is handed to {@link #nomatch}.
     * By default matching chunks are dropped and non-matching chunks are emitted.
     */
    private class FilterStream extends PrintStream {
        private Pattern pattern;

        /**
         * @param out the stream that receives the emitted output.
         * @param regex the filter expression, compiled with {@link Pattern#COMMENTS}.
         */
        public FilterStream (OutputStream out, String regex) {
            super(out);
            pattern = Pattern.compile(regex, Pattern.COMMENTS);
        }

        /** Writes the chunk to the underlying stream without further filtering here. */
        public void emit (byte[] buf, int off, int len) {
            super.write(buf,off,len);
        }

        /** Handles a chunk whose first line matches the pattern: it is dropped. */
        public void match (byte[] buf, int off, int len) {
            // do not emit
        }

        /** Handles a chunk whose first line does not match the pattern: it is emitted. */
        public void nomatch (byte[] buf, int off, int len) {
            emit(buf,off,len);
        }

        /**
         * Filters one chunk of output. Only the chunk's first line (up to the first newline)
         * is tested against the pattern, and the whole chunk is then either matched or not.
         * A chunk that starts with a newline is always treated as a non-match.
         */
        @Override
        public void write (byte[] buf, int off, int len) {
            // Decode only the bytes actually being written. The caller's buffer may start at
            // a non-zero offset and may contain stale data beyond off+len.
            String line = new String(buf, off, len);
            int n = line.indexOf('\n');
            if (n < 0) n = line.length();
            if (n == 0) {
                nomatch(buf,off,len);
            } else {
                line = line.substring(0,n);

                if (pattern.matcher(line).find()) {
                    match(buf,off,len);
                } else {
                    nomatch(buf,off,len);
                }
            }
        }
    }

    /**
     * A stream that emits everything written to it and fails the test when a chunk's first
     * line starts with "ERROR" or "FATAL", followed by a whitespace character and "[".
     * The failure is raised as an AssertionFailedError from inside the write call.
     */
    private class ErrorWatchStream extends FilterStream {
        public ErrorWatchStream (OutputStream out) {
            super(out, "^(ERROR|FATAL)\\s\\[");
        }

        /** Emits the error chunk so it stays visible, then fails the test. */
        @Override
        public void match (byte[] buf, int off, int len) {
            emit(buf,off,len);
            fail("Unexpected error detected!");
        }
    }

    /**
     * Adds a filter to the output stream (using a literal string).
     * @param literal Anything matching this string will not be shown.
     */
    protected void filterOutput(String literal) {
        filterOutputRegex(Pattern.quote(literal));
    }

    /**
     * Adds a filter to the output stream (using a regular expression). This replaces any
     * filter previously installed by this instance. Errors in unfiltered output are still
     * detected. Log4J is reconfigured afterwards.
     * @param regex Anything matching this regular expression will not be shown.
     */
    protected void filterOutputRegex(String regex) {
        System.setOut(new FilterStream(new ErrorWatchStream(stdOut), regex));
        System.setErr(new FilterStream(new ErrorWatchStream(stdErr), regex));
        configureLog4J();
    }

    /**
     * Restores the output stream to its unfiltered state (errors are still detected though).
     * Log4J is reconfigured afterwards.
     */
    protected void unfilterOutput() {
        System.setOut(new ErrorWatchStream(stdOut));
        System.setErr(new ErrorWatchStream(stdErr));
        configureLog4J();
    }


    // Tools for handling threads.


    /**
     * Every TestThread created by this test registers itself here.
     * Threads in a test should extend TestThread, implement the runTest method and exit on
     * interrupt (isInterrupted, InterruptedException, ClosedByInterruptException).
     * Use runTestThreads to run all instantiated TestThreads (it will block until they finish).
     * Exceptions thrown by the TestThread will cause test failure.
     */
    protected ArrayList<TestThread> threads = new ArrayList<TestThread>();
    // Result object passed to run(TestResult); thread failures are reported into it.
    private junit.framework.TestResult testResult = null;
    // Set once any TestThread has failed; prevents further threads from being started.
    protected volatile boolean error = false;

    /**
     * Overrides the JUnit run method to remember the test result, so that failures in
     * TestThreads can be reported to it.
     * @param result the JUnit result collector for this run.
     */
    public void run (final junit.framework.TestResult result) {
        testResult = result;
        super.run(testResult);
    }

    /**
     * Extend this class to implement threads in the test code. Instances register themselves
     * in {@link #threads} and are daemon threads. Any Throwable escaping runTest() stops all
     * registered threads and is reported as a test failure or error.
     */
    protected abstract class TestThread extends Thread {

        /** Registers the thread with the enclosing test and marks it as a daemon. */
        public TestThread () {
            threads.add(this);
            setDaemon(true);
        }

        /** Here is where the test thread does its thing. */
        public abstract void runTest () throws Throwable;

        /** Here is where the test thread should stop doing its thing. */
        public abstract void kill () throws Throwable;

        /** Runs the test body and reports any Throwable it throws. */
        @Override
        public void run () {
            try {
                runTest();
            } catch (Throwable t) {
                // Raise the abort flag before stopping the others, so start() refuses to
                // launch threads that have not been started yet while shutdown is underway.
                error = true;
                stopThreads();
                handleException(t);
            }
        }

        /** Starts the thread unless a test thread has already failed. */
        @Override
        public void start () {
            synchronized (threads) {
                if (!error) {
                    super.start();
                }
            }
        }
    }

    /**
     * Starts all registered threads that have not been started yet.
     * Threads that already ran are skipped, because a Java thread cannot be restarted.
     */
    protected void startTestThreads () {
        for (TestThread tt : threads) {
            if (tt.getState() == Thread.State.NEW) {
                tt.start();
            }
        }
    }

    /**
     * Waits until all started registered threads finish. If the waiting thread is
     * interrupted, or a ThreadDeath is raised, waiting for that particular thread stops.
     */
    protected void waitForTestThreads () {
        for (TestThread tt : threads) {
            if (tt.getState() != Thread.State.NEW && tt.getState() != Thread.State.TERMINATED) {
                try {
                    tt.join();
                } catch (InterruptedException e) { 
                    /* stop waiting */
                } catch (ThreadDeath t) { 
                    /* stop waiting */
                }
            }
        }

    }

    /**
     * Start all registered threads and wait until they finish.
     */
    protected void runTestThreads () {
        startTestThreads();
        waitForTestThreads();
    }

    /**
     * Reports an exception to JUnit. Assertion failures are recorded as failures; anything
     * else is recorded as an error.
     * @param t the exception to report.
     */
    private void handleException (Throwable t) {
        error = true;
        if (t instanceof junit.framework.AssertionFailedError) {
            testResult.addFailure(this, (junit.framework.AssertionFailedError)t);
        } else {
            testResult.addError(this, t);
        }
    }


    /**
     * Runs the test like JUnit's runBare (setUp, test method, tearDown), but first reads the
     * test method's {@link TestInfo} annotation. The test is skipped when
     * {@link #hasToRun(TestInfo)} says so. After tearDown, the ports the test declares are
     * checked to be free again. When several steps fail, the first failure is the one thrown.
     */
    @Override
    public void runBare() throws Throwable {
        assertNotNull("TestCase.fName cannot be null", fName); // Some VMs crash when calling getMethod(null,null);
        Method runMethod= null;
        try {
            // use getMethod to get all public inherited
            // methods. getDeclaredMethods returns all
            // methods of this class but excludes the
            // inherited ones.
            runMethod= getClass().getMethod(fName, (Class[])null);
        } catch (NoSuchMethodException e) {
            fail("Method \""+fName+"\" not found");
        }
        TestInfo info = runMethod.getAnnotation(TestInfo.class);

        if (hasToRun(info)) {
            //This fragment of code is the original (super.runBare)
            Throwable exception= null;
            setUp();
            try {
                runTest();
            } catch (Throwable running) {
                exception= running;
            }
            finally {
                try {
                    tearDown();
                } catch (Throwable tearingDown) {
                    if (exception == null) exception= tearingDown;
                }
            }
            //This part is also new
            try {
                checkCleanExit(info);
            } catch (Throwable checkingExit) {
                // Keep the original test/tearDown failure if there was one: leftover ports
                // are most likely a consequence of it, and the cause is more useful.
                if (exception == null) exception = checkingExit;
            }

            if (exception != null) throw exception;
        }    
    }


    /**
     * Fails the test if any port listed in the annotation's requiresPort() is still in use.
     * @param info the test method's annotation, or null (then nothing is checked).
     */
    private void checkCleanExit(TestInfo info) {
        if (null == info) return;
        if (isSomePortInUse(info.requiresPort())) {
            fail("This test did not clean up some of it ports.");
        }
    }

    /**
     * Decides whether the current test should run, based on the "TestsToRun" system
     * property. If the property is unset, every test runs. If it is set, annotated tests run
     * only when their type matches (or the property is "ALL") and their ports are free.
     * Unannotated tests run with a warning.
     * @param info the test method's annotation, or null.
     * @return true if the test should run.
     */
    private boolean hasToRun(TestInfo info) {
        String ttr = System.getProperty("TestsToRun");
        if (null != ttr) {
            // Validates the property value; only checked when assertions are enabled.
            assert(ttr.equals(TestInfo.TestType.UNIT.toString())
                    || ttr.equals(TestInfo.TestType.INTEGRATION.toString())
                    || ttr.equals(TestInfo.TestType.SYSTEM.toString())
                    || ttr.equals("ALL"));
            if (null != info) {
                return annotatedTestHasToRun(info, ttr);
            } else {
                System.out.println("Warning: test \"" + fName + "\" without test info. Running it anyway.");
                return true;
            }
        } else {
            return true;
        }
    }

    /**
     * Decides whether an annotated test should run for the requested run level.
     * @param info the test method's annotation (non-null).
     * @param runlevel the value of the "TestsToRun" system property.
     * @return false if the test type does not match the run level or a required port is in
     *         use, true otherwise.
     */
    private boolean annotatedTestHasToRun(TestInfo info, String runlevel) {
        if (!runlevel.equals("ALL") && !runlevel.equals(info.testType().toString())) {
            System.out.println("Skipping test  \"" + fName + "\": it is not of the type specified.");
            return false;
        }
        return !isSomePortInUse(info.requiresPort());
    }

    /**
     * Checks each port by trying to bind a ServerSocket to it; the probe socket is always
     * closed. Prints a message for the first port that cannot be bound.
     * @param ports the ports to probe.
     * @return true as soon as one port cannot be bound, false if all could be bound.
     */
    private boolean isSomePortInUse(int[] ports) {
        for (int p : ports) {
            ServerSocket s = null;
            try {
                s = new ServerSocket(p);
            } catch (IOException e) {
                System.out.println("Port " + p + " is in use. " + e);
                return true;
            } finally {
                Execute.close(s);
            }
        }
        return false;
    }

    /**
     * Gets the name of a TestCase
     * @return the name of the TestCase
     */
    @Override
    public String getName() {
        return fName;
    }
    /**
     * Sets the name of a TestCase
     * @param name the name to set
     */
    @Override
    public void setName(String name) {
        fName= name;
        super.setName(name);
    }

    /**
     * Does everything possible to stop all registered threads: interrupt, kill(), then the
     * deprecated Thread.stop(). A ThreadDeath raised by this (for example when stopping the
     * calling thread itself) is ignored. Any other Throwable is printed to System.err, and
     * the loop moves on to the next thread.
     */
    @SuppressWarnings("deprecation")
    private void stopThreads () {
        for (TestThread tt : threads) {
            try {
                tt.interrupt();
                tt.kill();
                tt.stop();
            } catch (ThreadDeath t) { 
                /* expected, ignore */
            } catch (Throwable t) { 
                System.err.println("Exception thrown while stopping a TestThread: "+t);
            }
        }
    }

}

